'''Novel-View Human Action Synthesis

Implementation of the I3D model introduced in:
    J. Carreira and A. Zisserman, "Quo Vadis, Action Recognition? A New Model
    and the Kinetics Dataset", in Proc. CVPR, 2017.
    Code reference: https://github.com/piergiaj/pytorch-i3d

'''


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np

import os
import sys
from collections import OrderedDict, namedtuple


class MaxPool3dSamePadding(nn.MaxPool3d):
  """3D max pooling whose padding is computed from the input size at run time.

  The input is padded so that each pooled axis yields ceil(size / stride)
  outputs ("SAME"-style padding). Padding is applied with `F.pad` using its
  default constant fill before delegating to `nn.MaxPool3d.forward`.

  `kernel_size` and `stride` may be given either as 3-element sequences or as
  single ints (applied to every axis), matching what `nn.MaxPool3d` accepts.
  """

  @staticmethod
  def _axis_value(spec, dim):
    """Return the entry of `spec` for axis `dim`; a scalar applies to all axes."""
    if isinstance(spec, (tuple, list)):
      return spec[dim]
    return spec

  def compute_pad(self, dim, s):
    """Return the total padding needed along axis `dim` for an input of size `s`.

    Args:
      dim: axis index (0 = time, 1 = height, 2 = width).
      s: input size along that axis.

    Returns:
      Non-negative int: total number of elements to add along the axis.
    """
    kernel = self._axis_value(self.kernel_size, dim)
    stride = self._axis_value(self.stride, dim)
    remainder = s % stride
    if remainder == 0:
      return max(kernel - stride, 0)
    return max(kernel - remainder, 0)

  def forward(self, x):
    """Pad the 5-D input `x` (batch, channel, t, h, w) and max-pool it."""
    (_, _, t, h, w) = x.size()

    pad = []
    # F.pad expects (front, back) pairs starting from the LAST dimension,
    # hence the width, height, time order.
    for dim, size in ((2, w), (1, h), (0, t)):
      total = self.compute_pad(dim, size)
      front = total // 2
      # Any odd leftover element goes to the back of the axis.
      pad.extend((front, total - front))

    x = F.pad(x, tuple(pad))
    return super(MaxPool3dSamePadding, self).forward(x)


class Unit3D(nn.Module):
  """Conv3d followed by optional BatchNorm3d and an optional activation.

  The convolution itself is built with `padding=0`; `forward` pads the input
  dynamically so that each axis yields ceil(size / stride) outputs
  ("SAME"-style padding).
  """

  def __init__(self, in_channels,
        output_channels,
        kernel_shape=(1, 1, 1),
        stride=(1, 1, 1),
        padding=0,
        activation_fn=F.relu,
        use_batch_norm=True,
        use_bias=False,
        name='unit_3d'):
    """Initializes Unit3D module.

    Args:
      in_channels: number of input channels of the convolution.
      output_channels: number of output channels of the convolution.
      kernel_shape: kernel size, a 3-element sequence or a single int.
      stride: stride, a 3-element sequence or a single int.
      padding: stored on the module but NOT applied; padding is always
        computed at run time in `forward`.
      activation_fn: callable applied to the output, or None to skip it.
      use_batch_norm: whether to apply BatchNorm3d after the convolution.
      use_bias: whether the convolution has a bias term.
      name: descriptive name stored on the module.
    """
    super(Unit3D, self).__init__()

    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._stride = stride
    self._use_batch_norm = use_batch_norm
    self._activation_fn = activation_fn
    self._use_bias = use_bias
    self.name = name
    self.padding = padding

    # padding=0 on purpose: forward() pads the input itself.
    self.conv3d = nn.Conv3d(in_channels=in_channels,
                out_channels=self._output_channels,
                kernel_size=self._kernel_shape,
                stride=self._stride,
                padding=0,
                bias=self._use_bias)

    if self._use_batch_norm:
      self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)

  @staticmethod
  def _axis_value(spec, dim):
    """Return the entry of `spec` for axis `dim`; a scalar applies to all axes."""
    if isinstance(spec, (tuple, list)):
      return spec[dim]
    return spec

  def compute_pad(self, dim, s):
    """Return the total padding needed along axis `dim` for an input of size `s`.

    Args:
      dim: axis index (0 = time, 1 = height, 2 = width).
      s: input size along that axis.

    Returns:
      Non-negative int: total number of elements to add along the axis.
    """
    kernel = self._axis_value(self._kernel_shape, dim)
    stride = self._axis_value(self._stride, dim)
    remainder = s % stride
    if remainder == 0:
      return max(kernel - stride, 0)
    return max(kernel - remainder, 0)

  def forward(self, x):
    """Pad the 5-D input `x` (batch, channel, t, h, w), then conv/BN/activation."""
    (_, _, t, h, w) = x.size()

    pad = []
    # F.pad expects (front, back) pairs starting from the LAST dimension,
    # hence the width, height, time order.
    for dim, size in ((2, w), (1, h), (0, t)):
      total = self.compute_pad(dim, size)
      front = total // 2
      # Any odd leftover element goes to the back of the axis.
      pad.extend((front, total - front))

    x = F.pad(x, tuple(pad))

    x = self.conv3d(x)
    if self._use_batch_norm:
      x = self.bn(x)
    if self._activation_fn is not None:
      x = self._activation_fn(x)
    return x



class InceptionModule(nn.Module):
  """Four-branch Inception block whose branch outputs are concatenated on dim 1.

  Branches, all fed with the same input:
    0: 1x1x1 Unit3D
    1: 1x1x1 Unit3D -> 3x3x3 Unit3D
    2: 1x1x1 Unit3D -> 3x3x3 Unit3D
    3: 3x3x3 stride-1 MaxPool3dSamePadding -> 1x1x1 Unit3D

  Args:
    in_channels: number of channels of the block input.
    out_channels: indexable with six channel counts, in the order
      (b0, b1a, b1b, b2a, b2b, b3b).
    name: prefix used to build the names of the inner units.
  """
  def __init__(self, in_channels, out_channels, name):
    super(InceptionModule, self).__init__()

    # Branch 0: single pointwise convolution.
    self.b0 = Unit3D(
        in_channels=in_channels,
        output_channels=out_channels[0],
        kernel_shape=[1, 1, 1],
        padding=0,
        name=name + '/Branch_0/Conv3d_0a_1x1')

    # Branch 1: pointwise reduction followed by a 3x3x3 convolution.
    self.b1a = Unit3D(
        in_channels=in_channels,
        output_channels=out_channels[1],
        kernel_shape=[1, 1, 1],
        padding=0,
        name=name + '/Branch_1/Conv3d_0a_1x1')
    self.b1b = Unit3D(
        in_channels=out_channels[1],
        output_channels=out_channels[2],
        kernel_shape=[3, 3, 3],
        name=name + '/Branch_1/Conv3d_0b_3x3')

    # Branch 2: same layout as branch 1 with its own channel counts.
    self.b2a = Unit3D(
        in_channels=in_channels,
        output_channels=out_channels[3],
        kernel_shape=[1, 1, 1],
        padding=0,
        name=name + '/Branch_2/Conv3d_0a_1x1')
    self.b2b = Unit3D(
        in_channels=out_channels[3],
        output_channels=out_channels[4],
        kernel_shape=[3, 3, 3],
        name=name + '/Branch_2/Conv3d_0b_3x3')

    # Branch 3: size-preserving max pool followed by a pointwise convolution.
    self.b3a = MaxPool3dSamePadding(
        kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=0)
    self.b3b = Unit3D(
        in_channels=in_channels,
        output_channels=out_channels[5],
        kernel_shape=[1, 1, 1],
        padding=0,
        name=name + '/Branch_3/Conv3d_0b_1x1')

    self.name = name

  def forward(self, x):
    """Run the four branches on `x` and concatenate their outputs on dim 1."""
    branch_outputs = [
        self.b0(x),
        self.b1b(self.b1a(x)),
        self.b2b(self.b2a(x)),
        self.b3b(self.b3a(x)),
    ]
    return torch.cat(branch_outputs, dim=1)


class InceptionI3d(nn.Module):
  """Inception-v1 I3D architecture.
  The model is introduced in:
    Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
    Joao Carreira, Andrew Zisserman
    https://arxiv.org/pdf/1705.07750v1.pdf.
  See also the Inception architecture, introduced in:
    Going deeper with convolutions
    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
    http://arxiv.org/pdf/1409.4842v1.pdf.

  Depending on `action`, `forward` returns either classification logits or a
  named tuple with the activations at PERCEPTUAL_ENDPOINTS.
  """

  # Endpoints of the model in order. During construction, all the endpoints up
  # to a designated `final_endpoint` are built, stored in `self.end_points`
  # and registered as sub-modules under their endpoint name.
  VALID_ENDPOINTS = (
    'Conv3d_1a_7x7',
    'MaxPool3d_2a_3x3',
    'Conv3d_2b_1x1',
    'Conv3d_2c_3x3',
    'MaxPool3d_3a_3x3',
    'Mixed_3b',
    'Mixed_3c',
    'MaxPool3d_4a_3x3',
    'Mixed_4b',
    'Mixed_4c',
    'Mixed_4d',
    'Mixed_4e',
    'Mixed_4f',
    'MaxPool3d_5a_2x2',
    'Mixed_5b',
    'Mixed_5c',
    'Logits',
    'Predictions',
  )

  # Endpoints whose activations are collected by `forward` (in this order).
  PERCEPTUAL_ENDPOINTS = (
    'Conv3d_1a_7x7',
    'Conv3d_2b_1x1',
    'Mixed_3c',
    'Mixed_4b',
  )

  # Return type of `forward` in feature mode. Created once here rather than
  # on every forward pass.
  _I3DOutputs = namedtuple("I3DOutputs", ['h1', 'h2', 'h3', 'h4'])

  def __init__(self, num_classes=400, spatial_squeeze=True,
        final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5, action=False):
    """Initializes I3D model instance.
    Args:
     num_classes: The number of outputs in the logit layer (default 400, which
       matches the Kinetics dataset).
     spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
       before returning (default True).
     final_endpoint: The model contains many possible endpoints.
       `final_endpoint` specifies the last endpoint for the model to be built
       up to. `final_endpoint` must be one of
       InceptionI3d.VALID_ENDPOINTS (default 'Logits'). The pooling/dropout/
       logits head is only created when every backbone endpoint is built,
       i.e. for 'Logits' or 'Predictions'.
     name: A string (optional). Prefix for the names given to sub-units.
     in_channels: Number of channels of the input to the first convolution.
     dropout_keep_prob: Value passed unchanged to `nn.Dropout` (see the note
       in the constructor body).
     action: If True, the head uses a [2, 4, 4] average pool and `forward`
       returns logits; otherwise the pool is [2, 7, 7] and `forward` returns
       the perceptual features.
    Raises:
     ValueError: if `final_endpoint` is not recognized.
    """

    if final_endpoint not in self.VALID_ENDPOINTS:
      raise ValueError('Unknown final endpoint %s' % final_endpoint)

    super(InceptionI3d, self).__init__()
    self._num_classes = num_classes
    self._spatial_squeeze = spatial_squeeze
    self._final_endpoint = final_endpoint
    self.logits = None
    self.action = action

    self.end_points = {}
    for end_point, make_layer in self._backbone_spec(in_channels):
      # Layers are instantiated lazily, one at a time and in network order,
      # so nothing past `final_endpoint` is ever created.
      self.end_points[end_point] = make_layer(name + end_point)
      if self._final_endpoint == end_point:
        # Register what was built so far; without this a truncated model
        # would expose no parameters and could not run forward().
        self.build()
        return

    # Classification head ('Logits' / 'Predictions').
    pool_kernel = [2, 4, 4] if self.action else [2, 7, 7]
    self.avg_pool = nn.AvgPool3d(kernel_size=pool_kernel, stride=(1, 1, 1))
    # NOTE(review): nn.Dropout's argument is the probability of zeroing an
    # element, so despite its name `dropout_keep_prob` acts as a drop
    # probability here. Kept as-is to preserve existing behaviour.
    self.dropout = nn.Dropout(dropout_keep_prob)
    self.logits = self._make_logits()

    self.build()

  @staticmethod
  def _backbone_spec(in_channels):
    """Return ordered (endpoint name, factory) pairs for the backbone layers.

    Each factory takes the scoped layer name and returns a new module.
    """
    def unit(in_ch, out_ch, kernel, **kwargs):
      return lambda scoped_name: Unit3D(
          in_channels=in_ch, output_channels=out_ch, kernel_shape=kernel,
          name=scoped_name, **kwargs)

    def pool(kernel, stride):
      return lambda scoped_name: MaxPool3dSamePadding(
          kernel_size=kernel, stride=stride, padding=0)

    def mixed(in_ch, out_chs):
      return lambda scoped_name: InceptionModule(in_ch, out_chs, scoped_name)

    return (
      ('Conv3d_1a_7x7', unit(in_channels, 64, [7, 7, 7], stride=(2, 2, 2), padding=(3, 3, 3))),
      ('MaxPool3d_2a_3x3', pool([1, 3, 3], (1, 2, 2))),
      ('Conv3d_2b_1x1', unit(64, 64, [1, 1, 1], padding=0)),
      ('Conv3d_2c_3x3', unit(64, 192, [3, 3, 3], padding=1)),
      ('MaxPool3d_3a_3x3', pool([1, 3, 3], (1, 2, 2))),
      ('Mixed_3b', mixed(192, [64, 96, 128, 16, 32, 32])),
      ('Mixed_3c', mixed(256, [128, 128, 192, 32, 96, 64])),
      ('MaxPool3d_4a_3x3', pool([3, 3, 3], (2, 2, 2))),
      ('Mixed_4b', mixed(128+192+96+64, [192, 96, 208, 16, 48, 64])),
      ('Mixed_4c', mixed(192+208+48+64, [160, 112, 224, 24, 64, 64])),
      ('Mixed_4d', mixed(160+224+64+64, [128, 128, 256, 24, 64, 64])),
      ('Mixed_4e', mixed(128+256+64+64, [112, 144, 288, 32, 64, 64])),
      ('Mixed_4f', mixed(112+288+64+64, [256, 160, 320, 32, 128, 128])),
      ('MaxPool3d_5a_2x2', pool([2, 2, 2], (2, 2, 2))),
      ('Mixed_5b', mixed(256+320+128+128, [256, 160, 320, 32, 128, 128])),
      ('Mixed_5c', mixed(256+320+128+128, [384, 192, 384, 48, 128, 128])),
    )

  def _make_logits(self):
    """Build the 1x1x1 logits unit for the current `self._num_classes`."""
    return Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
              kernel_shape=[1, 1, 1],
              padding=0,
              activation_fn=None,
              use_batch_norm=False,
              use_bias=True,
              name='logits')

  def replace_logits(self, num_classes):
    """Replace the logits unit with a freshly initialised one for `num_classes`."""
    self._num_classes = num_classes
    self.logits = self._make_logits()

  def build(self):
    """Register every module in `self.end_points` under its endpoint name."""
    for end_point, module in self.end_points.items():
      self.add_module(end_point, module)

  def _apply_head(self, x):
    """Apply average pooling, dropout and the logits unit to `x`.

    Raises:
      RuntimeError: if the model was truncated before the head was built.
    """
    if self.logits is None:
      raise RuntimeError(
          'Classification head is not available (final_endpoint=%s)'
          % self._final_endpoint)
    return self.logits(self.dropout(self.avg_pool(x)))

  def forward(self, x, logits=False, perc=True):
    """Run the network on `x`.

    Args:
      x: input tensor; the sub-modules unpack it as (batch, channel, t, h, w).
      logits: if True, return the raw head output (no squeezing), regardless
        of `action`.
      perc: only consulted when `action` is True, where it controls whether
        intermediate activations are collected; they are not part of the
        return value in that mode.

    Returns:
      * `logits=True`: output of the logits unit.
      * `action=True`: output of the logits unit, passed through
        `squeeze(3).squeeze(2).squeeze(2)` when `spatial_squeeze` is set.
      * otherwise: an `I3DOutputs` named tuple (h1, h2, h3, h4) holding the
        activations recorded at PERCEPTUAL_ENDPOINTS.
    """
    hs = []
    for end_point in self.VALID_ENDPOINTS:
      if end_point in self.end_points:
        # The head is kept in attributes rather than in end_points, so this
        # guard only matters if end_points is extended externally.
        if end_point in ['Logits', 'Predictions'] and not self.action:
          break
        x = self._modules[end_point](x)
      if end_point in self.PERCEPTUAL_ENDPOINTS and (not self.action or perc):
        hs.append(x)

    if logits:
      return self._apply_head(x)

    if self.action:
      x = self._apply_head(x)
      if self._spatial_squeeze:
        return x.squeeze(3).squeeze(2).squeeze(2)
      # Previously this path left the result unassigned and crashed.
      return x

    return self._I3DOutputs(hs[0], hs[1], hs[2], hs[3])

  def set_requires_grad(self, requires_grad=False):
    """Set `requires_grad` on every parameter of the model."""
    for param in self.parameters():
      param.requires_grad = requires_grad
